{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "bf650d66",
   "metadata": {},
   "source": [
    "# GEMM 算子"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "73303d82",
   "metadata": {
    "tags": [
     "remove-cell"
    ]
   },
   "outputs": [],
   "source": [
    "import set_env"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "83fd1a83",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "========GEMM 128=========\n",
      "----- GEMM GOPS End-to-End Test-------\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "[20:54:16] /media/pc/data/board/arria10/lxw/tasks/tvm-new/src/tir/transforms/arg_binder.cc:95: Warning: Trying to bind buffer to another one with lower alignment requirement  required_alignment=256, provided_alignment=64\n",
      "2025-07-27 20:54:16.770 INFO load_module /tmp/tmpklerl2r0/gemm.o\n",
      "[20:54:16] /media/pc/data/board/arria10/lxw/tasks/tvm-new/src/runtime/profiling.cc:101: Warning: No timer implementation for ext_dev, using default timer instead. It may be inaccurate or have extra overhead.\n",
      "[20:54:16] /media/pc/data/board/arria10/lxw/tasks/tvm-new/src/tir/transforms/arg_binder.cc:95: Warning: Trying to bind buffer to another one with lower alignment requirement  required_alignment=256, provided_alignment=64\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Execution statistics:\n",
      "\tinp_load_nbytes :           344064\n",
      "\twgt_load_nbytes :           344064\n",
      "\tacc_load_nbytes :                0\n",
      "\tuop_load_nbytes :             1008\n",
      "\tout_store_nbytes:           344064\n",
      "\tgemm_counter    :           172032\n",
      "\talu_counter     :            64512\n",
      "NORMAL\n",
      "\tTime cost = 0.000646294 sec/op, 6.48978 GOPS\n",
      "----- GEMM Unit Test-------\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "2025-07-27 20:54:17.031 INFO load_module /tmp/tmpklerl2r0/gemm.o\n",
      "[20:54:17] /media/pc/data/board/arria10/lxw/tasks/tvm-new/src/tir/transforms/arg_binder.cc:95: Warning: Trying to bind buffer to another one with lower alignment requirement  required_alignment=256, provided_alignment=64\n",
      "2025-07-27 20:54:17.303 INFO load_module /tmp/tmpklerl2r0/gemm.o\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Execution statistics:\n",
      "\tinp_load_nbytes :                0\n",
      "\twgt_load_nbytes :                0\n",
      "\tacc_load_nbytes :                0\n",
      "\tuop_load_nbytes :              756\n",
      "\tout_store_nbytes:                0\n",
      "\tgemm_counter    :           172032\n",
      "\talu_counter     :                0\n",
      "NORMAL\n",
      "\tTime cost = 0.000593391 sec/op, 7.06836 GOPS\n",
      "----- ALU Unit Test-------\n",
      "Execution statistics:\n",
      "\tinp_load_nbytes :                0\n",
      "\twgt_load_nbytes :                0\n",
      "\tacc_load_nbytes :                0\n",
      "\tuop_load_nbytes :              252\n",
      "\tout_store_nbytes:                0\n",
      "\tgemm_counter    :                0\n",
      "\talu_counter     :            64512\n",
      "NORMAL\n",
      "\tTime cost = 4.4641e-05 sec/op, 93.9563 GOPS\n",
      "\n"
     ]
    }
   ],
   "source": [
    "import tvm\n",
    "import tvm.testing\n",
    "from tvm import te\n",
    "import numpy as np\n",
    "from tvm.contrib import utils\n",
    "import vta.testing\n",
    "from vta.testing import simulator\n",
    "\n",
    "\n",
    "def test_gemm():\n",
    "    \"\"\"Benchmark a packed GEMM on VTA and print timing for each scenario.\n",
    "\n",
    "    The end-to-end scenario is additionally checked against a NumPy reference;\n",
    "    a mismatch raises through ``tvm.testing.assert_allclose``.\n",
    "    \"\"\"\n",
    "\n",
    "    def run_gemm_packed(env, remote, batch_size, channel, block):\n",
    "        \"\"\"Declare, schedule and time one packed GEMM configuration.\n",
    "\n",
    "        Parameters\n",
    "        ----------\n",
    "        env\n",
    "            Environment object passed in by ``vta.testing.run``; supplies the\n",
    "            BATCH / BLOCK_IN / BLOCK_OUT factors, dtypes, buffer scopes and the\n",
    "            stage implementations used by the schedule.\n",
    "        remote\n",
    "            Session object passed in by ``vta.testing.run``; used to upload and\n",
    "            load the compiled module and to obtain the device.\n",
    "        batch_size : int\n",
    "            Number of rows of the unpacked input matrix.\n",
    "        channel : int\n",
    "            Used for both the reduction and the output dimension.\n",
    "        block : int\n",
    "            Tile size of the blocked schedule; a falsy value selects the\n",
    "            untiled schedule.\n",
    "        \"\"\"\n",
    "        data_shape = (batch_size // env.BATCH, channel // env.BLOCK_IN, env.BATCH, env.BLOCK_IN)\n",
    "        weight_shape = (\n",
    "            channel // env.BLOCK_OUT,\n",
    "            channel // env.BLOCK_IN,\n",
    "            env.BLOCK_OUT,\n",
    "            env.BLOCK_IN,\n",
    "        )\n",
    "        res_shape = (batch_size // env.BATCH, channel // env.BLOCK_OUT, env.BATCH, env.BLOCK_OUT)\n",
    "        # To compute number of ops, use a x2 factor for FMA\n",
    "        num_ops = 2 * channel * channel * batch_size\n",
    "        # Upper saturation bound applied before the final cast to env.inp_dtype.\n",
    "        clip_max = (1 << (env.INP_WIDTH - 1)) - 1\n",
    "\n",
    "        ko = te.reduce_axis((0, channel // env.BLOCK_IN), name=\"ko\")\n",
    "        ki = te.reduce_axis((0, env.BLOCK_IN), name=\"ki\")\n",
    "\n",
    "        data = te.placeholder(data_shape, name=\"data\", dtype=env.inp_dtype)\n",
    "        weight = te.placeholder(weight_shape, name=\"weight\", dtype=env.wgt_dtype)\n",
    "        data_buf = te.compute(data_shape, lambda *i: data(*i), \"data_buf\")\n",
    "        weight_buf = te.compute(weight_shape, lambda *i: weight(*i), \"weight_buf\")\n",
    "        res_gem = te.compute(\n",
    "            res_shape,\n",
    "            lambda bo, co, bi, ci: te.sum(\n",
    "                data_buf[bo, ko, bi, ki].astype(env.acc_dtype)\n",
    "                * weight_buf[co, ko, ci, ki].astype(env.acc_dtype),\n",
    "                axis=[ko, ki],\n",
    "            ),\n",
    "            name=\"res_gem\",\n",
    "        )\n",
    "        res_shf = te.compute(res_shape, lambda *i: res_gem(*i) >> 8, name=\"res_shf\")\n",
    "        res_max = te.compute(res_shape, lambda *i: tvm.te.max(res_shf(*i), 0), \"res_max\")  # relu\n",
    "        res_min = te.compute(\n",
    "            res_shape, lambda *i: tvm.te.min(res_max(*i), clip_max), \"res_min\"\n",
    "        )  # saturate at clip_max\n",
    "        res = te.compute(res_shape, lambda *i: res_min(*i).astype(env.inp_dtype), name=\"res\")\n",
    "\n",
    "        def reference_result(data_packed, weight_packed):\n",
    "            \"\"\"NumPy model of the pipeline: blocked GEMM, ``>> 8``, clip, cast.\"\"\"\n",
    "            acc = np.zeros(res_shape).astype(env.acc_dtype)\n",
    "            for b in range(batch_size // env.BATCH):\n",
    "                for i in range(channel // env.BLOCK_OUT):\n",
    "                    for j in range(channel // env.BLOCK_IN):\n",
    "                        acc[b, i, :] += np.dot(\n",
    "                            data_packed[b, j, :].astype(env.acc_dtype),\n",
    "                            weight_packed[i, j].T.astype(env.acc_dtype),\n",
    "                        )\n",
    "            acc = np.right_shift(acc, 8)\n",
    "            return np.clip(acc, 0, clip_max).astype(res.dtype)\n",
    "\n",
    "        def verify(s, check_correctness=False):\n",
    "            \"\"\"Build schedule ``s``, time it on the device and return the timing result.\n",
    "\n",
    "            When ``check_correctness`` is true the device output is compared with\n",
    "            ``reference_result`` and a mismatch raises. The reference is only\n",
    "            computed in that case.\n",
    "            \"\"\"\n",
    "            mod = vta.build(\n",
    "                s,\n",
    "                [data, weight, res],\n",
    "                tvm.target.Target(\"ext_dev\", host=env.target_host),\n",
    "                name=\"gemm\",\n",
    "            )\n",
    "            temp = utils.tempdir()\n",
    "            obj_path = temp.relpath(\"gemm.o\")\n",
    "            mod.save(obj_path)\n",
    "            remote.upload(obj_path)\n",
    "            f = remote.load_module(\"gemm.o\")\n",
    "            dev = remote.ext_dev(0)\n",
    "            # Data in original format (unseeded, so inputs differ between runs)\n",
    "            data_orig = np.random.randint(-128, 128, size=(batch_size, channel)).astype(data.dtype)\n",
    "            weight_orig = np.random.randint(-128, 128, size=(channel, channel)).astype(weight.dtype)\n",
    "            # Repack into the blocked layouts declared by data_shape / weight_shape.\n",
    "            data_packed = data_orig.reshape(\n",
    "                batch_size // env.BATCH, env.BATCH, channel // env.BLOCK_IN, env.BLOCK_IN\n",
    "            ).transpose((0, 2, 1, 3))\n",
    "            weight_packed = weight_orig.reshape(\n",
    "                channel // env.BLOCK_OUT, env.BLOCK_OUT, channel // env.BLOCK_IN, env.BLOCK_IN\n",
    "            ).transpose((0, 2, 1, 3))\n",
    "            res_np = np.zeros(res_shape).astype(res.dtype)\n",
    "            data_arr = tvm.nd.array(data_packed, dev)\n",
    "            weight_arr = tvm.nd.array(weight_packed, dev)\n",
    "            res_arr = tvm.nd.array(res_np, dev)\n",
    "            time_f = f.time_evaluator(\"gemm\", dev, number=20)\n",
    "            on_simulator = env.TARGET in [\"sim\", \"tsim\"]\n",
    "            if on_simulator:\n",
    "                simulator.clear_stats()\n",
    "            cost = time_f(data_arr, weight_arr, res_arr)\n",
    "            if on_simulator:\n",
    "                stats = simulator.stats()\n",
    "                print(\"Execution statistics:\")\n",
    "                for k, v in stats.items():\n",
    "                    print(\"\\t{:<16}: {:>16}\".format(k, v))\n",
    "            if check_correctness:\n",
    "                res_unpack = res_arr.numpy().reshape(res_shape)\n",
    "                tvm.testing.assert_allclose(\n",
    "                    res_unpack, reference_result(data_packed, weight_packed)\n",
    "                )\n",
    "            return cost\n",
    "\n",
    "        def run_schedule(\n",
    "            load_inp, load_wgt, gemm, alu, store_out, print_ir, check_correctness=False\n",
    "        ):\n",
    "            \"\"\"Create the schedule with the given stage implementations and time it.\n",
    "\n",
    "            ``load_inp`` / ``load_wgt`` / ``alu`` / ``store_out`` are applied as\n",
    "            pragmas and ``gemm`` is used to tensorize ``res_gem``. ``print_ir``\n",
    "            prints the lowered IR; ``check_correctness`` is forwarded to ``verify``.\n",
    "            \"\"\"\n",
    "            s = te.create_schedule(res.op)\n",
    "            s[data_buf].set_scope(env.inp_scope)\n",
    "            s[weight_buf].set_scope(env.wgt_scope)\n",
    "            s[res_gem].set_scope(env.acc_scope)\n",
    "            s[res_shf].set_scope(env.acc_scope)\n",
    "            s[res_min].set_scope(env.acc_scope)\n",
    "            s[res_max].set_scope(env.acc_scope)\n",
    "\n",
    "            if block:\n",
    "                bblock = block // env.BATCH\n",
    "                iblock = block // env.BLOCK_IN\n",
    "                oblock = block // env.BLOCK_OUT\n",
    "                xbo, xco, xbi, xci = s[res].op.axis\n",
    "                xb1, xco1, xb2, xco2 = s[res].tile(xbo, xco, bblock, oblock)\n",
    "                store_pt = xb2\n",
    "\n",
    "                s[res_gem].compute_at(s[res], xco1)\n",
    "                s[res_shf].compute_at(s[res], xco1)\n",
    "                s[res_min].compute_at(s[res], xco1)\n",
    "                s[res_max].compute_at(s[res], xco1)\n",
    "\n",
    "                xbo, xco, xbi, xci = s[res_gem].op.axis\n",
    "                # Compute one line at a time\n",
    "                ko1, ko2 = s[res_gem].split(ko, iblock)\n",
    "                s[res_gem].reorder(ko1, ko2, xbo, xco, xbi, xci, ki)\n",
    "                s[data_buf].compute_at(s[res_gem], ko1)\n",
    "                s[weight_buf].compute_at(s[res_gem], ko1)\n",
    "            else:\n",
    "                xbo, xco, xbi, xci = s[res_gem].op.axis\n",
    "                s[res_gem].reorder(ko, xbo, xco, xbi, xci, ki)\n",
    "                store_pt = s[res].op.axis[0]\n",
    "\n",
    "            # Use VTA instructions (identical for both branches, applied last)\n",
    "            s[data_buf].pragma(s[data_buf].op.axis[0], load_inp)\n",
    "            s[weight_buf].pragma(s[weight_buf].op.axis[0], load_wgt)\n",
    "            s[res_gem].tensorize(xbi, gemm)\n",
    "            s[res_shf].pragma(s[res_shf].op.axis[0], alu)\n",
    "            s[res_min].pragma(s[res_min].op.axis[0], alu)\n",
    "            s[res_max].pragma(s[res_max].op.axis[0], alu)\n",
    "            s[res].pragma(store_pt, store_out)\n",
    "\n",
    "            if print_ir:\n",
    "                print(tvm.lower(s, [data, weight, res], simple_mode=True))\n",
    "            return verify(s, check_correctness)\n",
    "\n",
    "        # (stage name, environment attribute) in run_schedule argument order.\n",
    "        stage_attrs = (\n",
    "            (\"load_inp\", \"dma_copy\"),\n",
    "            (\"load_wgt\", \"dma_copy\"),\n",
    "            (\"gemm\", \"gemm\"),\n",
    "            (\"alu\", \"alu\"),\n",
    "            (\"store_out\", \"dma_copy\"),\n",
    "        )\n",
    "\n",
    "        def run_case(\n",
    "            title,\n",
    "            real_stages,\n",
    "            print_ir,\n",
    "            bandwidth_bits=None,\n",
    "            check_correctness=False,\n",
    "            blank_after=True,\n",
    "        ):\n",
    "            \"\"\"Time one scenario and print its report.\n",
    "\n",
    "            Stages named in ``real_stages`` are taken from ``env``; every other\n",
    "            stage is taken from ``env.mock``. When ``bandwidth_bits`` is given,\n",
    "            a bandwidth figure derived from it is added to the report.\n",
    "            \"\"\"\n",
    "            mock = env.mock\n",
    "            print(title)\n",
    "            with vta.build_config():\n",
    "                # Looked up inside the build config, as the per-test code did.\n",
    "                stages = [\n",
    "                    getattr(env if name in real_stages else mock, attr)\n",
    "                    for name, attr in stage_attrs\n",
    "                ]\n",
    "                cost = run_schedule(*stages, print_ir, check_correctness=check_correctness)\n",
    "                gops = (num_ops / cost.mean) / float(10**9)\n",
    "                print(\"NORMAL\")\n",
    "                if bandwidth_bits is None:\n",
    "                    print(\"\\tTime cost = %g sec/op, %g GOPS\" % (cost.mean, gops))\n",
    "                else:\n",
    "                    bandwidth = (bandwidth_bits / cost.mean) / float(10**9)\n",
    "                    print(\n",
    "                        \"\\tTime cost = %g sec/op, %g GOPS, bandwidth=%g Gbits\"\n",
    "                        % (cost.mean, gops, bandwidth)\n",
    "                    )\n",
    "            if blank_after:\n",
    "                print(\"\")\n",
    "\n",
    "        def gemm_normal(print_ir):\n",
    "            \"\"\"End-to-end run with every stage real; the output is validated.\"\"\"\n",
    "            run_case(\n",
    "                \"----- GEMM GOPS End-to-End Test-------\",\n",
    "                {\"load_inp\", \"load_wgt\", \"gemm\", \"alu\", \"store_out\"},\n",
    "                print_ir,\n",
    "                check_correctness=True,\n",
    "                blank_after=False,\n",
    "            )\n",
    "\n",
    "        # The scenarios below replace stages with env.mock versions, so their\n",
    "        # output is not compared with the reference (presumably it would not\n",
    "        # match - TODO confirm what the mock stages leave in the result buffer).\n",
    "        def gemm_unittest(print_ir):\n",
    "            \"\"\"Time the schedule with only the GEMM stage real.\"\"\"\n",
    "            run_case(\"----- GEMM Unit Test-------\", {\"gemm\"}, print_ir, blank_after=False)\n",
    "\n",
    "        def alu_unittest(print_ir):\n",
    "            \"\"\"Time the schedule with only the ALU stages real.\"\"\"\n",
    "            run_case(\"----- ALU Unit Test-------\", {\"alu\"}, print_ir)\n",
    "\n",
    "        def load_inp_unittest(print_ir):\n",
    "            \"\"\"Time the schedule with only the input load real; reports bandwidth.\"\"\"\n",
    "            run_case(\n",
    "                \"----- LoadInp Unit Test-------\",\n",
    "                {\"load_inp\"},\n",
    "                print_ir,\n",
    "                bandwidth_bits=batch_size * channel * env.INP_WIDTH,\n",
    "            )\n",
    "\n",
    "        def load_wgt_unittest(print_ir):\n",
    "            \"\"\"Time the schedule with only the weight load real; reports bandwidth.\"\"\"\n",
    "            run_case(\n",
    "                \"----- LoadWgt Unit Test-------\",\n",
    "                {\"load_wgt\"},\n",
    "                print_ir,\n",
    "                bandwidth_bits=channel * channel * env.WGT_WIDTH,\n",
    "            )\n",
    "\n",
    "        def store_out_unittest(print_ir):\n",
    "            \"\"\"Time the schedule with only the output store real; reports bandwidth.\"\"\"\n",
    "            run_case(\n",
    "                \"----- StoreOut Unit Test-------\",\n",
    "                {\"store_out\"},\n",
    "                print_ir,\n",
    "                bandwidth_bits=batch_size * channel * env.OUT_WIDTH,\n",
    "            )\n",
    "\n",
    "        # Only these three scenarios are run; the load/store unit tests above\n",
    "        # are defined but not invoked here.\n",
    "        gemm_normal(False)\n",
    "        gemm_unittest(False)\n",
    "        alu_unittest(False)\n",
    "\n",
    "    def _run(env, remote):\n",
    "        print(\"========GEMM 128=========\")\n",
    "        run_gemm_packed(env, remote, 128, 128, 128)\n",
    "\n",
    "    vta.testing.run(_run)\n",
    "\n",
    "\n",
    "# Entry point: run the benchmark when this cell/script is executed directly\n",
    "# (the recorded cell output shows this branch is taken in the notebook).\n",
    "if __name__ == \"__main__\":\n",
    "    test_gemm()\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d0b5d46b",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.12.2"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
